from __future__ import division, print_function
import sys
import copy
import time
import numpy as np
import os.path as osp
import datetime
import warnings
import torch
import torch.nn as nn

import torchreid
from torchreid.utils import (
    Logger, AverageMeter, check_isfile, open_all_layers, save_checkpoint,
    set_random_seed, collect_env_info, open_specified_layers,
    load_pretrained_weights, compute_model_complexity
)
from torchreid.data.transforms import (
    Resize, Compose, ToTensor, Normalize, Random2DTranslation,
    RandomHorizontalFlip
)

import models
import datasets
from default_parser import init_parser, optimizer_kwargs, lr_scheduler_kwargs

# Build the option parser and parse the options once, at module import time.
# `args` is then read as a module-level global by every function below.
# NOTE(review): `init_parser` comes from `default_parser`; presumably it
# returns an argparse parser, so `parse_args()` reads `sys.argv` -- confirm.
parser = init_parser()
args = parser.parse_args()


def init_dataset(use_gpu):
    """Create the train/val/test data loaders for ``args.dataset``.

    The training split gets random-translation and horizontal-flip
    augmentation; the validation and test splits are only resized. All three
    share the same ``ToTensor`` + ``Normalize`` tail.

    Args:
        use_gpu (bool): forwarded to every ``DataLoader`` as ``pin_memory``.

    Returns:
        tuple: ``(trainloader, valloader, testloader, num_attrs, attr_dict)``,
        where ``num_attrs`` and ``attr_dict`` are read from the training set.
    """
    # Per-channel mean/std applied after conversion to a tensor.
    norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_tf = Compose(
        [
            Random2DTranslation(args.height, args.width, p=0.5),
            RandomHorizontalFlip(),
            ToTensor(),
            norm,
        ]
    )
    eval_tf = Compose([Resize([args.height, args.width]), ToTensor(), norm])

    def build_split(mode, transform, verbose):
        # One dataset object per split; only mode/transform/verbosity differ.
        return datasets.init_dataset(
            args.dataset,
            root=args.root,
            transform=transform,
            mode=mode,
            verbose=verbose
        )

    def build_loader(split, is_train):
        # Only the training loader shuffles and drops the last partial batch.
        return torch.utils.data.DataLoader(
            split,
            batch_size=args.batch_size,
            shuffle=is_train,
            num_workers=args.workers,
            pin_memory=use_gpu,
            drop_last=is_train
        )

    trainset = build_split('train', train_tf, True)
    valset = build_split('val', eval_tf, False)
    testset = build_split('test', eval_tf, False)

    # Attribute metadata is taken from the training split.
    num_attrs, attr_dict = trainset.num_attrs, trainset.attr_dict

    trainloader = build_loader(trainset, True)
    valloader = build_loader(valset, False)
    testloader = build_loader(testset, False)

    return trainloader, valloader, testloader, num_attrs, attr_dict


def main():
    """Train or evaluate an attribute model as configured by ``args``.

    Steps, in order:
      1. seed the RNGs, choose CPU/GPU and route ``sys.stdout`` through a
         ``Logger`` created for ``<save_dir>/train.log`` (``test.log`` when
         ``args.evaluate`` is set);
      2. build the data loaders and the (optionally weighted) BCE criterion;
      3. build the model and optionally load ``args.load_weights``;
      4. if ``args.evaluate`` is set, run ``test`` once and return;
      5. otherwise build the optimizer and LR scheduler, optionally resume
         from ``args.resume``, then alternate ``train``/``test`` up to
         ``args.max_epoch``, saving a checkpoint after every epoch.

    Checkpoints hold the keys ``state_dict``, ``epoch``, ``label_mA``,
    ``optimizer`` and ``scheduler``. Checkpoints without a ``scheduler``
    entry can still be resumed; a warning is emitted because the
    learning-rate schedule then restarts from its initial state.
    """
    set_random_seed(args.seed)
    use_gpu = torch.cuda.is_available() and not args.use_cpu
    log_name = 'test.log' if args.evaluate else 'train.log'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))

    print('** Arguments **')
    for key in sorted(vars(args)):
        print('{}: {}'.format(key, vars(args)[key]))
    print('\n')
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if use_gpu:
        torch.backends.cudnn.benchmark = True
    else:
        warnings.warn(
            'Currently using CPU, however, GPU is highly recommended'
        )

    # NOTE: valloader is built but not used below; the per-epoch evaluation
    # (and therefore the "best" checkpoint) is driven by the test split.
    dataset_vars = init_dataset(use_gpu)
    trainloader, valloader, testloader, num_attrs, attr_dict = dataset_vars

    if args.weighted_bce:
        print('Use weighted binary cross entropy')
        print('Computing the weights ...')
        bce_weights = torch.zeros(num_attrs, dtype=torch.float)
        for _, attrs, _ in trainloader:
            bce_weights += attrs.sum(0) # sum along the batch dim
        # The train loader drops the last partial batch, so exactly
        # len(trainloader) * batch_size samples were visited above.
        bce_weights /= len(trainloader) * args.batch_size
        print('Sample ratio for each attribute: {}'.format(bce_weights))
        # weight = exp(-positive_ratio): frequent attributes get less weight.
        bce_weights = torch.exp(-1 * bce_weights)
        print('BCE weights: {}'.format(bce_weights))
        bce_weights = bce_weights.expand(args.batch_size, num_attrs)
        criterion = nn.BCEWithLogitsLoss(weight=bce_weights)

    else:
        print('Use plain binary cross entropy')
        criterion = nn.BCEWithLogitsLoss()

    print('Building model: {}'.format(args.arch))
    model = models.build_model(
        args.arch,
        num_attrs,
        pretrained=not args.no_pretrained,
        use_gpu=use_gpu
    )
    num_params, flops = compute_model_complexity(
        model, (1, 3, args.height, args.width)
    )
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if use_gpu:
        model = nn.DataParallel(model).cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        test(model, testloader, attr_dict, use_gpu)
        return

    optimizer = torchreid.optim.build_optimizer(
        model, **optimizer_kwargs(args)
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, **lr_scheduler_kwargs(args)
    )

    start_epoch = args.start_epoch
    best_result = -np.inf
    if args.resume and check_isfile(args.resume):
        # Map storages to CPU when running without a GPU so that a checkpoint
        # written from CUDA tensors can still be deserialised.
        checkpoint = torch.load(
            args.resume, map_location=None if use_gpu else 'cpu'
        )
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restore the scheduler together with the optimizer; otherwise the
        # schedule restarts at step 0 on top of an already-decayed lr.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            warnings.warn(
                'Checkpoint has no "scheduler" entry; the learning-rate '
                'schedule restarts from its initial state'
            )
        start_epoch = checkpoint['epoch']
        best_result = checkpoint['label_mA']
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start epoch: {}'.format(start_epoch))
        print('- label_mA: {}'.format(best_result))

    time_start = time.time()

    for epoch in range(start_epoch, args.max_epoch):
        train(
            epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu
        )
        test_outputs = test(model, testloader, attr_dict, use_gpu)
        label_mA = test_outputs[0]
        is_best = label_mA > best_result
        if is_best:
            best_result = label_mA

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                # Number of completed epochs, i.e. the epoch to resume from.
                'epoch': epoch + 1,
                # Stored as a plain float so the checkpoint carries no
                # pickled NumPy scalar.
                'label_mA': float(label_mA),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            },
            args.save_dir,
            is_best=is_best
        )

    elapsed = round(time.time() - time_start)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print('Elapsed {}'.format(elapsed))


def train(epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu):
    """Run one training epoch, then advance the LR scheduler by one step.

    Args:
        epoch (int): zero-based index of the current epoch.
        model: network to optimise; put into training mode here.
        criterion: loss called as ``criterion(outputs, attrs)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once, after the last batch.
        trainloader: iterable of batches whose first two items are the images
            and the attribute labels.
        use_gpu (bool): move each batch to CUDA before the forward pass.
    """
    loss_meter = AverageMeter()
    batch_meter = AverageMeter()
    data_meter = AverageMeter()
    model.train()

    cur_epoch = epoch + 1
    fix_base = (
        cur_epoch <= args.fixbase_epoch and args.open_layers is not None
    )
    if not fix_base:
        open_all_layers(model)
    else:
        # During the first `fixbase_epoch` epochs only the listed layers train.
        print(
            '* Only train {} (epoch: {}/{})'.format(
                args.open_layers, cur_epoch, args.fixbase_epoch
            )
        )
        open_specified_layers(model, args.open_layers)

    log_template = (
        'Epoch: [{0}/{1}][{2}/{3}]\t'
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
        'Lr {lr:.6f}\t'
        'Eta {eta}'
    )

    tick = time.time()
    for step, batch in enumerate(trainloader, start=1):
        # Time spent waiting for the loader to deliver this batch.
        data_meter.update(time.time() - tick)

        imgs = batch[0]
        attrs = batch[1]
        if use_gpu:
            imgs = imgs.cuda()
            attrs = attrs.cuda()

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, attrs)
        loss.backward()
        optimizer.step()

        # Full iteration time: data loading + forward + backward + update.
        batch_meter.update(time.time() - tick)

        loss_meter.update(loss.item(), imgs.size(0))

        if step % args.print_freq == 0:
            # Remaining batches in this epoch plus all batches of the epochs
            # still to come, priced at the running average batch time.
            num_batches = len(trainloader)
            batches_left = (
                num_batches - step + (args.max_epoch - cur_epoch) * num_batches
            )
            eta_seconds = batch_meter.avg * batches_left
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print(
                log_template.format(
                    cur_epoch,
                    args.max_epoch,
                    step,
                    len(trainloader),
                    batch_time=batch_meter,
                    data_time=data_meter,
                    loss=loss_meter,
                    lr=optimizer.param_groups[0]['lr'],
                    eta=eta_str
                )
            )

        tick = time.time()

    scheduler.step()


def _safe_divide(numer, denom):
    """Return ``numer / denom`` element-wise, with 0 wherever ``denom`` is 0."""
    result = np.zeros_like(numer, dtype=float)
    np.divide(numer, denom, out=result, where=denom != 0)
    return result


def _append_predictions(
    fpath, img_paths, probs, labels, attr_dict, prob_thre
):
    """Append a per-image prediction report for one batch to ``fpath``."""
    # The context manager closes the file even if formatting raises.
    with open(fpath, 'a') as txtfile:
        for img_path, img_probs, img_labels in zip(img_paths, probs, labels):
            txtfile.write('{}\n'.format(img_path))
            txtfile.write('*** Correct prediction ***\n')
            # Every ground-truth attribute is listed with its score.
            for attr_idx, (label, prob) in enumerate(
                zip(img_labels, img_probs)
            ):
                if label:
                    attr_name = attr_dict[attr_idx]
                    txtfile.write('{}: {:.1%}  '.format(attr_name, prob))
            txtfile.write('\n*** Incorrect prediction ***\n')
            # False positives, using the same rule as the binarisation.
            for attr_idx, (label, prob) in enumerate(
                zip(img_labels, img_probs)
            ):
                if not label and prob >= prob_thre:
                    attr_name = attr_dict[attr_idx]
                    txtfile.write('{}: {:.1%}  '.format(attr_name, prob))
            txtfile.write('\n\n')


@torch.no_grad()
def test(model, testloader, attr_dict, use_gpu):
    """Evaluate ``model`` on ``testloader`` and print attribute metrics.

    Model outputs are binarised at 0.5 and compared with the labels.
    NOTE(review): this assumes the model returns per-attribute probabilities
    (not raw logits) in eval mode -- confirm against ``models.build_model``.

    Instance-based accuracy, precision and recall are averaged over samples.
    A sample whose denominator is zero (nothing predicted and/or nothing
    labelled) contributes 0 instead of turning the whole metric into NaN.
    F1 is 0 when precision + recall is 0.

    Args:
        model: network to evaluate; put into eval mode here.
        testloader: iterable yielding ``(imgs, attrs, img_paths)`` batches.
        attr_dict: maps an attribute index to its name; only used for the
            report written when ``args.save_prediction`` is set.
        use_gpu (bool): move the images to CUDA before the forward pass.

    Returns:
        tuple: ``(label_mA, ins_acc, ins_prec, ins_rec, ins_f1)``.
    """
    batch_time = AverageMeter()
    model.eval()

    num_persons = 0
    prob_thre = 0.5
    ins_acc = 0
    ins_prec = 0
    ins_rec = 0
    mA_history = {
        'correct_pos': 0,
        'real_pos': 0,
        'correct_neg': 0,
        'real_neg': 0
    }

    print('Testing ...')

    for batch_idx, data in enumerate(testloader):
        imgs, attrs, img_paths = data
        if use_gpu:
            imgs = imgs.cuda()

        # Only the forward pass is timed.
        end = time.time()
        orig_outputs = model(imgs)
        batch_time.update(time.time() - end)

        orig_outputs = orig_outputs.data.cpu().numpy()
        attrs = attrs.data.numpy()

        # transform raw outputs to attributes (binary codes)
        outputs = (orig_outputs >= prob_thre).astype(orig_outputs.dtype)

        # compute label-based metric (per-attribute sums over the batch dim)
        mA_history['correct_pos'] += (outputs * attrs).sum(0)
        mA_history['real_pos'] += attrs.sum(0)
        mA_history['correct_neg'] += ((1-outputs) * (1-attrs)).sum(0)
        mA_history['real_neg'] += (1 - attrs).sum(0)

        outputs = outputs.astype(bool)
        attrs = attrs.astype(bool)

        # compute instance-based metrics (per-sample sums over attributes)
        intersect = (outputs & attrs).astype(float).sum(1)
        union = (outputs | attrs).astype(float).sum(1)
        num_pred = outputs.astype(float).sum(1)
        num_true = attrs.astype(float).sum(1)
        ins_acc += _safe_divide(intersect, union).sum()
        ins_prec += _safe_divide(intersect, num_pred).sum()
        ins_rec += _safe_divide(intersect, num_true).sum()

        num_persons += imgs.size(0)

        if (batch_idx+1) % args.print_freq == 0:
            print(
                'Processed batch {}/{}'.format(batch_idx + 1, len(testloader))
            )

        if args.save_prediction:
            _append_predictions(
                osp.join(args.save_dir, 'prediction.txt'),
                img_paths, orig_outputs, attrs, attr_dict, prob_thre
            )

    print(
        '=> BatchTime(s)/BatchSize(img): {:.4f}/{}'.format(
            batch_time.avg, args.batch_size
        )
    )

    ins_acc /= num_persons
    ins_prec /= num_persons
    ins_rec /= num_persons
    f1_denom = ins_prec + ins_rec
    ins_f1 = (2*ins_prec*ins_rec) / f1_denom if f1_denom > 0 else 0.0

    # Label-based denominators depend only on the ground truth, so they are
    # left unguarded: an attribute with no positives (or no negatives) in the
    # evaluated data still surfaces as NaN rather than being scored silently.
    term1 = mA_history['correct_pos'] / mA_history['real_pos']
    term2 = mA_history['correct_neg'] / mA_history['real_neg']
    label_mA_verbose = (term1+term2) * 0.5
    label_mA = label_mA_verbose.mean()

    print('* Results *')
    print('  # test persons: {}'.format(num_persons))
    print('  (instance-based)  accuracy:      {:.1%}'.format(ins_acc))
    print('  (instance-based)  precition:     {:.1%}'.format(ins_prec))
    print('  (instance-based)  recall:        {:.1%}'.format(ins_rec))
    print('  (instance-based)  f1-score:      {:.1%}'.format(ins_f1))
    print('  (label-based)     mean accuracy: {:.1%}'.format(label_mA))
    print('  mA for each attribute: {}'.format(label_mA_verbose))

    return label_mA, ins_acc, ins_prec, ins_rec, ins_f1


# Run the training/evaluation entry point only when executed as a script.
if __name__ == '__main__':
    main()
